package test.abcode.protobuf.websocket;

import com.google.protobuf.ByteString;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.util.JsonFormat;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.BinaryMessage;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.BinaryWebSocketHandler;
import test.abcode.protobuf.chatroom.ChatMessageProto;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;

/**
 * Binary WebSocket handler for a single chat room.
 *
 * <p>Every frame is a serialized {@code ChatMessageProto.GameMessage}; its {@code reqType} selects the
 * handler and its {@code content} carries the type-specific payload. Replies and broadcasts are wrapped
 * in a new {@code GameMessage} that echoes the request type and sets the response type.
 *
 * <p>Handler callbacks for different sessions may run on different threads, so all shared state uses
 * concurrent collections and sends to a single session are serialized.
 */
@Slf4j
@Component
public class ChatWebSocketHandler extends BinaryWebSocketHandler {

    /** All open connections, keyed by session id; broadcast targets. */
    private final Map<String, WebSocketSession> sessionMap = new ConcurrentHashMap<>();
    /** The user each session joined as; absent until JOIN_ROOM and removed on leave/close. */
    private final Map<String, ChatMessageProto.UserInfo> sessionUserInMap = new ConcurrentHashMap<>();
    /**
     * Current room members. Copy-on-write so concurrent joins/leaves from different handler threads are
     * safe and {@code addAllUsers} always iterates a consistent snapshot.
     */
    private final List<ChatMessageProto.UserInfo> roomUserList = new CopyOnWriteArrayList<>();

    /** Registers the new session; the user is only added to the room once it sends JOIN_ROOM. */
    @Override
    public void afterConnectionEstablished(WebSocketSession session) throws Exception {
        sessionMap.put(session.getId(), session);
    }

    /**
     * Decodes a binary frame as a {@code GameMessage} and dispatches it by request type.
     *
     * @throws Exception if the frame or its content is not valid protobuf, or a reply cannot be sent
     */
    @Override
    protected void handleBinaryMessage(WebSocketSession session, BinaryMessage message) throws Exception {
        // Parse the buffer directly: ByteBuffer.array() would expose the whole backing array and ignore
        // the buffer's offset/position/limit (and fails for non-array-backed buffers).
        ChatMessageProto.GameMessage gameMessage = ChatMessageProto.GameMessage.parseFrom(message.getPayload());

        switch (gameMessage.getReqType()) {
            case HEARTBEAT:
                handleHeartbeat(session, gameMessage);
                break;
            case JOIN_ROOM:
                handleJoinRoom(session, gameMessage);
                break;
            case LEAVE_ROOM:
                handleLeaveRoom(session, gameMessage);
                break;
            case SEND_MESSAGE:
                handleSendMessage(session, gameMessage);
                break;
            default:
                log.warn("Unknown message type: {}", gameMessage.getReqType());
                break;
        }
    }

    /**
     * Replies to a heartbeat with a successful {@code HeartbeatRes} sent only to the requesting session.
     *
     * @param session     the session that sent the heartbeat
     * @param gameMessage the received heartbeat envelope
     * @throws IOException if the reply cannot be sent
     */
    private void handleHeartbeat(WebSocketSession session, ChatMessageProto.GameMessage gameMessage) throws IOException {
        ChatMessageProto.UserInfo user = sessionUserInMap.get(session.getId());
        // Heartbeats may arrive before JOIN_ROOM; JsonFormat cannot print a null message.
        log.info("|handleHeartbeat|收到心跳|sessionId:{}|来自:{}", session.getId(),
                user == null ? null : JsonFormat.printer().print(user));

        ChatMessageProto.HeartbeatRes heartbeatRes = ChatMessageProto.HeartbeatRes.newBuilder()
                .setSuccess(true)
                .build();

        sendMessage(session, gameMessage.getReqType(), ChatMessageProto.MessageType.HEARTBEAT_RES, heartbeatRes.toByteString());
    }

    /**
     * Adds the user from a {@code JoinRoom} payload to the room and broadcasts a {@code JoinRoomNotice}
     * containing the joining user and the full member list.
     *
     * @param session     the joining session
     * @param gameMessage envelope whose content is a serialized {@code JoinRoom}
     * @throws IOException if the content is not a valid {@code JoinRoom} or cannot be logged as JSON
     */
    private void handleJoinRoom(WebSocketSession session, ChatMessageProto.GameMessage gameMessage) throws IOException {
        ChatMessageProto.UserInfo user = ChatMessageProto.JoinRoom.parseFrom(gameMessage.getContent()).getUser();

        ChatMessageProto.UserInfo previous = sessionUserInMap.put(session.getId(), user);
        if (previous != null) {
            // Re-joining from the same session replaces the earlier entry instead of duplicating it.
            roomUserList.remove(previous);
        }
        roomUserList.add(user);

        ChatMessageProto.JoinRoomNotice joinRoomRes = ChatMessageProto.JoinRoomNotice.newBuilder()
                .setSuccess(true)
                .setMessage("Welcome to the chat room!")
                .setJoinUser(user)
                .addAllUsers(roomUserList)
                .build();
        sendToAll(gameMessage.getReqType(), ChatMessageProto.MessageType.JOIN_ROOM_NOTICE, joinRoomRes.toByteString());
        log.info("User joined room: {}", JsonFormat.printer().print(user));
    }

    /**
     * Broadcasts the text of a {@code SendMsg} payload as a {@code SendMsgNotice} attributed to the
     * session's joined user. Messages from sessions that have not joined are ignored.
     *
     * @param session     the sending session
     * @param gameMessage envelope whose content is a serialized {@code SendMsg}
     * @throws IOException if the content is not a valid {@code SendMsg}
     */
    private void handleSendMessage(WebSocketSession session, ChatMessageProto.GameMessage gameMessage) throws IOException {
        ChatMessageProto.SendMsg send = ChatMessageProto.SendMsg.parseFrom(gameMessage.getContent());
        ChatMessageProto.UserInfo sender = sessionUserInMap.get(session.getId());
        if (sender == null) {
            // Protobuf builders reject null fields; without a joined user there is no sender to report.
            log.warn("Ignoring message from session {} that has not joined the room", session.getId());
            return;
        }
        log.info("Received message: {}", send.getContent());

        ChatMessageProto.SendMsgNotice sendMsgNotice = ChatMessageProto.SendMsgNotice.newBuilder()
                .setSuccess(true)
                .setContent(send.getContent())
                .setSender(sender)
                .build();
        sendToAll(gameMessage.getReqType(), ChatMessageProto.MessageType.SEND_MESSAGE_NOTICE, sendMsgNotice.toByteString());
    }

    /**
     * Removes the session's joined user from the room and broadcasts a {@code LeaveRoomNotice} with the
     * remaining members.
     *
     * <p>The leaving user is taken from the session, not from the client-supplied {@code userId}, so a
     * client cannot remove other users and the notice always names the user actually removed.
     *
     * @param session     the leaving session
     * @param gameMessage envelope whose content is a serialized {@code LeaveRoom}
     * @throws IOException if the content is not a valid {@code LeaveRoom}
     */
    private void handleLeaveRoom(WebSocketSession session, ChatMessageProto.GameMessage gameMessage) throws IOException {
        ChatMessageProto.LeaveRoom leaveRoom = ChatMessageProto.LeaveRoom.parseFrom(gameMessage.getContent());
        ChatMessageProto.UserInfo leaver = sessionUserInMap.remove(session.getId());
        if (leaver == null) {
            log.warn("Ignoring LEAVE_ROOM (userId={}) from session {} that has not joined the room",
                    leaveRoom.getUserId(), session.getId());
            return;
        }
        log.info("User left room: {}", leaver.getUserId());
        roomUserList.remove(leaver);

        ChatMessageProto.LeaveRoomNotice leaveRoomRes = ChatMessageProto.LeaveRoomNotice.newBuilder()
                .setSuccess(true)
                .setMessage("User left the chat room.")
                .setLeaveUser(leaver)
                .addAllUsers(roomUserList)
                .build();
        sendToAll(gameMessage.getReqType(), ChatMessageProto.MessageType.LEAVE_ROOM_NOTICE, leaveRoomRes.toByteString());
    }

    /**
     * Sends a response envelope to a single session.
     *
     * @param session the recipient
     * @param reqType request type being answered
     * @param resType response type
     * @param content serialized response payload
     * @throws IOException if sending fails
     */
    private void sendMessage(WebSocketSession session, ChatMessageProto.MessageType reqType,
                             ChatMessageProto.MessageType resType, ByteString content) throws IOException {
        sendBinary(session, buildMessage(reqType, resType, content).toByteArray());
    }

    /**
     * Sends a response envelope to every open session. A failure on one recipient is logged and does not
     * stop delivery to the others (nor does it propagate to the session that triggered the broadcast).
     *
     * @param reqType request type being answered
     * @param resType response type
     * @param content serialized response payload
     */
    private void sendToAll(ChatMessageProto.MessageType reqType, ChatMessageProto.MessageType resType, ByteString content) {
        byte[] serializedMessage = buildMessage(reqType, resType, content).toByteArray();

        for (WebSocketSession target : sessionMap.values()) {
            if (!target.isOpen()) {
                continue;
            }
            try {
                sendBinary(target, serializedMessage);
            } catch (IOException e) {
                log.warn("Failed to send {} to session {}", resType, target.getId(), e);
            }
        }
    }

    /** Wraps a payload in a {@code GameMessage} envelope with the given request/response types. */
    private static ChatMessageProto.GameMessage buildMessage(ChatMessageProto.MessageType reqType,
                                                             ChatMessageProto.MessageType resType, ByteString content) {
        return ChatMessageProto.GameMessage.newBuilder()
                .setReqType(reqType)
                .setResType(resType)
                .setContent(content)
                .build();
    }

    /**
     * Writes one binary frame to a session.
     *
     * <p>{@code WebSocketSession.sendMessage} must not be called concurrently on the same session, and
     * broadcasts triggered from different sessions' handler threads can target the same recipient, so
     * sends are serialized on the session object. A fresh {@code BinaryMessage} is created per send so no
     * buffer position is shared between recipients.
     */
    private static void sendBinary(WebSocketSession target, byte[] bytes) throws IOException {
        synchronized (target) {
            target.sendMessage(new BinaryMessage(bytes));
        }
    }

    /**
     * Forgets a closed session: removes it from the broadcast targets and drops its user from the room so
     * later notices do not list a disconnected member. No leave notice is broadcast here.
     */
    @Override
    public void afterConnectionClosed(WebSocketSession session, CloseStatus status) {
        sessionMap.remove(session.getId());
        ChatMessageProto.UserInfo user = sessionUserInMap.remove(session.getId());
        if (user != null) {
            roomUserList.remove(user);
        }
    }
}
